from maix import nn, camera, display, image, time, touchscreen, app
import os
from maix import gpio, pinmap

# Status LEDs: each pad is routed to its GPIO function, opened as an output,
# and driven low so both LEDs start switched off.
def _init_led(pad):
    """Route *pad* (e.g. "A15") to GPIO, open it as an output, and drive it low."""
    gpio_name = "GPIO" + pad
    pinmap.set_pin_function(pad, gpio_name)
    pin = gpio.GPIO(gpio_name, gpio.Mode.OUT)
    pin.value(0)
    return pin


led2 = _init_led("A15")  # LED for known face
led1 = _init_led("A16")  # LED for unknown face

# Global state shared by the functions below.
pressed_flag = [False] * 3  # touch-button state, presumably one flag per on-screen button
learn_id = 0  # numeric suffix used for the next "Person_<n>" face label
person_lookup = {}  # Person_ID -> human-readable name, persisted in /root/person.txt
recognizer = None  # face recognizer, created by initialize_recognizer()

# Load the person lookup table from the file
def load_person_lookup(path="/root/person.txt"):
    """Populate the global ``person_lookup`` from a ``key=value`` text file.

    Each line of the form ``Person_N=Name`` (or ``Person_N="Name"``) adds or
    overwrites one entry.  Only the first ``=`` separates key from value, so
    names may themselves contain ``=``.  Surrounding whitespace and optional
    double quotes around the name are removed.  Lines without ``=`` or with
    an empty key are ignored.  Existing entries not mentioned in the file are
    kept.

    Args:
        path: lookup file to read; defaults to the on-device location.
    """
    global person_lookup
    if not os.path.exists(path):
        print("No lookup table found. Starting fresh.")
        return
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            # partition() splits on the first "=" only, unlike split("="),
            # which dropped names containing "=" entirely.
            key, sep, value = line.strip().partition("=")
            key = key.strip()
            if sep and key:
                person_lookup[key] = value.strip().strip('"')
    print("Person lookup table loaded.")

# Save the person lookup table to the file
def save_person_lookup(path="/root/person.txt"):
    """Write the global ``person_lookup`` to *path*, one ``key=value`` per line.

    The table is first written and fsynced to ``<path>.tmp`` and then renamed
    over *path* with ``os.replace``.  An interruption mid-write (e.g. the board
    losing power) therefore leaves the previous table intact rather than a
    truncated file.

    Args:
        path: lookup file to write; defaults to the on-device location.
    """
    tmp_path = path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        f.writelines(f"{person}={name}\n" for person, name in person_lookup.items())
        f.flush()
        os.fsync(f.fileno())
    os.replace(tmp_path, path)
    print(f"Person lookup saved to {path}")

# Initialize the recognizer and load previously saved faces
def initialize_recognizer():
    """Create the global face recognizer and restore previously saved faces.

    The recognizer is built from the detection model
    /root/models/yolov8n_face.mud and the feature model
    /root/models/insghtface_webface_r50.mud with dual buffering enabled.
    Learned faces are then loaded from /root/faces.bin when that file exists.
    """
    global recognizer
    faces_file = "/root/faces.bin"
    # NOTE(review): the "insghtface" spelling presumably matches the model
    # file shipped on the device; verify before "correcting" it.
    recognizer = nn.FaceRecognizer(
        detect_model="/root/models/yolov8n_face.mud",
        feature_model="/root/models/insghtface_webface_r50.mud",
        dual_buff=True,
    )
    if not os.path.exists(faces_file):
        print("No saved faces found. Starting fresh.")
        return
    print("Loading saved faces from /root/faces.bin...")
    recognizer.load_faces(faces_file)

# Main Function
def main(disp):
    """Run the face-recognition UI loop until "back" is tapped or the app exits.

    Each frame does the following:
      * read the touchscreen;
      * run the global ``recognizer`` on a camera frame;
      * light LED2 while any known face is visible, and LED1 while any
        unknown face is visible;
      * draw face boxes, labels and the on-screen buttons;
      * show the frame.

    Buttons fire once, when the finger is lifted after touching them:
      * learn -- every currently unknown face is stored under a new
        ``Person_N`` label; the faces and lookup files are then saved;
      * clear -- all learned faces and lookup names are removed and saved;
      * back  -- leave the loop.

    Both LEDs are switched off when the function returns or raises.

    Args:
        disp: the ``display.Display`` frames are shown on.
    """
    global learn_id
    cam = camera.Camera(recognizer.input_width(), recognizer.input_height(), recognizer.input_format())
    ts = touchscreen.TouchScreen()

    # Continue numbering after faces the recognizer already knows (e.g. ones
    # restored from disk), so a new face never reuses an existing label.
    learn_id = max(learn_id, _next_learn_id(recognizer.labels))

    # Button rectangles (x, y, w, h) in camera-image coordinates (drawn on img).
    back_btn_pos = (0, 0, 70, 30)
    learn_btn_pos = (0, recognizer.input_height() - 30, 60, 30)
    clear_btn_pos = (recognizer.input_width() - 60, recognizer.input_height() - 30, 60, 30)

    # The same rectangles mapped to display coordinates, used for touch hit-tests.
    back_btn_disp_pos = image.resize_map_pos(cam.width(), cam.height(), disp.width(), disp.height(), image.Fit.FIT_CONTAIN, *back_btn_pos)
    learn_btn_disp_pos = image.resize_map_pos(cam.width(), cam.height(), disp.width(), disp.height(), image.Fit.FIT_CONTAIN, *learn_btn_pos)
    clear_btn_disp_pos = image.resize_map_pos(cam.width(), cam.height(), disp.width(), disp.height(), image.Fit.FIT_CONTAIN, *clear_btn_pos)

    # (index into pressed_flag, display rectangle); pressed_flag is laid out
    # as [learn, clear, back], and back is hit-tested first.
    buttons = ((2, back_btn_disp_pos), (0, learn_btn_disp_pos), (1, clear_btn_disp_pos))

    def draw_btns(img):
        """Draw the back / learn / clear buttons onto *img*."""
        img.draw_rect(*back_btn_pos, image.Color.from_rgb(255, 255, 255), 2)
        img.draw_string(back_btn_pos[0] + 4, back_btn_pos[1] + 8, "< back", image.COLOR_WHITE)
        img.draw_rect(*learn_btn_pos, image.Color.from_rgb(255, 255, 255), 2)
        img.draw_string(learn_btn_pos[0] + 4, learn_btn_pos[1] + 8, "learn", image.COLOR_WHITE)
        img.draw_rect(*clear_btn_pos, image.Color.from_rgb(255, 255, 255), 2)
        img.draw_string(clear_btn_pos[0] + 4, clear_btn_pos[1] + 8, "clear", image.COLOR_WHITE)

    def is_in_button(x, y, btn_pos):
        """Return True if (x, y) lies strictly inside the (x, y, w, h) rectangle."""
        return btn_pos[0] < x < btn_pos[0] + btn_pos[2] and btn_pos[1] < y < btn_pos[1] + btn_pos[3]

    def on_touch(x, y, pressed):
        """Return (learn, clear, back); each button fires once, on release.

        While the screen is touched, the button under the finger is latched
        in ``pressed_flag``.  On the first untouched read, the latched flags
        are reported and cleared.  Firing on every touched frame would learn,
        clear and rewrite the files repeatedly for a single tap.
        """
        if pressed:
            for idx, pos in buttons:
                if is_in_button(x, y, pos):
                    pressed_flag[idx] = True
                    break
            return False, False, False
        events = tuple(pressed_flag)
        pressed_flag[:] = [False] * len(pressed_flag)
        return events

    try:
        while not app.need_exit():
            x, y, pressed = ts.read()
            learn, clear, back = on_touch(x, y, pressed)
            if back:
                break
            if clear:
                print("Clearing all faces...")
                # labels[0] is the recognizer's "unknown" entry; every other
                # entry is a learned face, removed one at a time from index 0.
                for _ in range(len(recognizer.labels) - 1):
                    recognizer.remove_face(0)
                person_lookup.clear()
                learn_id = 0
                save_person_lookup()
                recognizer.save_faces("/root/faces.bin")
                print("All faces cleared.")

            img = cam.read()
            faces = recognizer.recognize(img, 0.5, 0.45, 0.85, learn, learn)
            labels = recognizer.labels
            known_seen = False
            unknown_seen = False
            learned_any = False

            for obj in faces:
                if obj.class_id != 0:  # Known face
                    known_seen = True
                    color = image.COLOR_GREEN
                else:  # Unknown face
                    unknown_seen = True
                    color = image.COLOR_RED

                # Display name: the recognizer's stored label for this class,
                # translated through person_lookup when a name was assigned.
                if 0 <= obj.class_id < len(labels):
                    label = labels[obj.class_id]
                else:
                    label = f"Person_{obj.class_id}"
                name = person_lookup.get(label, label)
                img.draw_rect(obj.x, obj.y, obj.w, obj.h, color)
                img.draw_string(obj.x, obj.y - 10, f"{obj.class_id + 1}: {name}", color)

                if learn and obj.class_id == 0:  # New face
                    new_label = f"Person_{learn_id}"
                    recognizer.add_face(obj, new_label)
                    # Placeholder name; a user-edited name for this label is kept.
                    person_lookup.setdefault(new_label, new_label)
                    learn_id += 1
                    learned_any = True

            # Persist once per frame, not once per learned face.
            if learned_any:
                save_person_lookup()
                recognizer.save_faces("/root/faces.bin")

            # Roles follow the pin setup: LED2 = known face, LED1 = unknown face.
            led2.value(1 if known_seen else 0)
            led1.value(1 if unknown_seen else 0)

            draw_btns(img)
            disp.show(img)
    finally:
        led1.value(0)
        led2.value(0)


def _next_learn_id(labels):
    """Return the first free N for a new ``Person_N`` label.

    Labels not of the form ``Person_<decimal digits>`` (such as the
    recognizer's "unknown" placeholder) are ignored.  If no such label exists,
    0 is returned.
    """
    prefix = "Person_"
    ids = [
        int(label[len(prefix):])
        for label in labels
        if label.startswith(prefix) and label[len(prefix):].isdecimal()
    ]
    return max(ids, default=-1) + 1

# Run Program
disp = display.Display()
try:
    load_person_lookup()
    initialize_recognizer()
    main(disp)
except Exception as e:
    # Top-level boundary: log the full traceback for debugging, then keep the
    # error on screen until the user exits the app.
    import traceback

    traceback.print_exc()
    print(f"Error occurred: {e}")
    # Include the exception type: str(e) alone is empty for e.g. AssertionError.
    img = image.Image(disp.width(), disp.height())
    img.draw_string(0, 0, f"{type(e).__name__}: {e}", image.COLOR_WHITE)
    disp.show(img)
    while not app.need_exit():
        time.sleep_ms(100)
finally:
    # Never leave the indicator LEDs lit once the program stops.
    led1.value(0)
    led2.value(0)